{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/vik/.virtualenvs/nnets/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import torch\n",
    "from torch import nn\n",
    "import os\n",
    "\n",
    "sys.path.append(os.path.abspath(\"../../data\"))\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "BATCH_SIZE = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "from csv_data import WeatherDatasetWrapper\n",
    "\n",
    "class WeatherDataset(WeatherDatasetWrapper):\n",
    "    predictors = [\"tmax\", \"tmin\", \"rain\"]\n",
    "    target = \"tmax_tomorrow\"\n",
    "    sequence_length = 7\n",
    "\n",
    "wrapper = WeatherDataset(DEVICE)\n",
    "datasets = wrapper.generate_datasets(BATCH_SIZE)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "train = datasets[\"train\"]\n",
    "valid = datasets[\"validation\"]\n",
    "test = datasets[\"test\"]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NeuralNetwork, self).__init__()\n",
    "        self.rnn = nn.RNN(3, 4, 1, batch_first=True)\n",
    "        self.dense = nn.Linear(4, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x, hidden = self.rnn(x)\n",
    "        x = self.dense(x)\n",
    "        return x"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 valid loss: 24.160947932944314\n",
      "Epoch 1 valid loss: 22.97017237579486\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[5], line 16\u001B[0m\n\u001B[1;32m     13\u001B[0m     pred \u001B[38;5;241m=\u001B[39m model(sequence)\n\u001B[1;32m     15\u001B[0m     loss \u001B[38;5;241m=\u001B[39m loss_fn(pred, target)\n\u001B[0;32m---> 16\u001B[0m     \u001B[43mloss\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     17\u001B[0m     optimizer\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m     19\u001B[0m losses \u001B[38;5;241m=\u001B[39m []\n",
      "File \u001B[0;32m~/.virtualenvs/nnets/lib/python3.10/site-packages/torch/_tensor.py:488\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m    478\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m    479\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m    480\u001B[0m         Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m    481\u001B[0m         (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m    486\u001B[0m         inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m    487\u001B[0m     )\n\u001B[0;32m--> 488\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    489\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m    490\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/.virtualenvs/nnets/lib/python3.10/site-packages/torch/autograd/__init__.py:197\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m    192\u001B[0m     retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m    194\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m    195\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m    196\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 197\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m  \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m    198\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    199\u001B[0m \u001B[43m    \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "from statistics import mean\n",
    "\n",
    "model = NeuralNetwork().to(DEVICE)\n",
    "loss_fn = nn.MSELoss()\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)\n",
    "\n",
    "EPOCHS = 50\n",
    "for epoch in range(EPOCHS):\n",
    "    for batch, (sequence, target) in enumerate(train):\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        sequence = sequence.to(DEVICE)\n",
    "        pred = model(sequence)\n",
    "\n",
    "        loss = loss_fn(pred, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    losses = []\n",
    "    with torch.no_grad():\n",
    "        for batch, (sequence, target) in enumerate(valid):\n",
    "            sequence = sequence.to(DEVICE)\n",
    "            pred = model(sequence)\n",
    "            loss = loss_fn(pred, target)\n",
    "            losses.append(loss.item())\n",
    "\n",
    "    print(f\"Epoch {epoch} valid loss: {mean(losses)}\")"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
